1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package com.google.common.collect;
18
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 import static com.google.common.collect.CollectPreconditions.checkNonnegative;
22 import static com.google.common.collect.CollectPreconditions.checkRemove;
23
24 import com.google.common.annotations.GwtCompatible;
25 import com.google.common.annotations.GwtIncompatible;
26 import com.google.common.primitives.Ints;
27
28 import java.io.InvalidObjectException;
29 import java.io.ObjectStreamException;
30 import java.io.Serializable;
31 import java.util.ConcurrentModificationException;
32 import java.util.Iterator;
33 import java.util.Map;
34 import java.util.Set;
35
36 import javax.annotation.Nullable;
37
38
39
40
41
42
43
44
45
46
47 @GwtCompatible(emulated = true)
48 abstract class AbstractMapBasedMultiset<E> extends AbstractMultiset<E>
49 implements Serializable {
50
51 private transient Map<E, Count> backingMap;
52
53
54
55
56
57
58 private transient long size;
59
60
61 protected AbstractMapBasedMultiset(Map<E, Count> backingMap) {
62 this.backingMap = checkNotNull(backingMap);
63 this.size = super.size();
64 }
65
66
67 void setBackingMap(Map<E, Count> backingMap) {
68 this.backingMap = backingMap;
69 }
70
71
72
73
74
75
76
77
78
79
80 @Override
81 public Set<Multiset.Entry<E>> entrySet() {
82 return super.entrySet();
83 }
84
85 @Override
86 Iterator<Entry<E>> entryIterator() {
87 final Iterator<Map.Entry<E, Count>> backingEntries =
88 backingMap.entrySet().iterator();
89 return new Iterator<Multiset.Entry<E>>() {
90 Map.Entry<E, Count> toRemove;
91
92 @Override
93 public boolean hasNext() {
94 return backingEntries.hasNext();
95 }
96
97 @Override
98 public Multiset.Entry<E> next() {
99 final Map.Entry<E, Count> mapEntry = backingEntries.next();
100 toRemove = mapEntry;
101 return new Multisets.AbstractEntry<E>() {
102 @Override
103 public E getElement() {
104 return mapEntry.getKey();
105 }
106 @Override
107 public int getCount() {
108 Count count = mapEntry.getValue();
109 if (count == null || count.get() == 0) {
110 Count frequency = backingMap.get(getElement());
111 if (frequency != null) {
112 return frequency.get();
113 }
114 }
115 return (count == null) ? 0 : count.get();
116 }
117 };
118 }
119
120 @Override
121 public void remove() {
122 checkRemove(toRemove != null);
123 size -= toRemove.getValue().getAndSet(0);
124 backingEntries.remove();
125 toRemove = null;
126 }
127 };
128 }
129
130 @Override
131 public void clear() {
132 for (Count frequency : backingMap.values()) {
133 frequency.set(0);
134 }
135 backingMap.clear();
136 size = 0L;
137 }
138
139 @Override
140 int distinctElements() {
141 return backingMap.size();
142 }
143
144
145
146 @Override public int size() {
147 return Ints.saturatedCast(size);
148 }
149
150 @Override public Iterator<E> iterator() {
151 return new MapBasedMultisetIterator();
152 }
153
154
155
156
157
158
159 private class MapBasedMultisetIterator implements Iterator<E> {
160 final Iterator<Map.Entry<E, Count>> entryIterator;
161 Map.Entry<E, Count> currentEntry;
162 int occurrencesLeft;
163 boolean canRemove;
164
165 MapBasedMultisetIterator() {
166 this.entryIterator = backingMap.entrySet().iterator();
167 }
168
169 @Override
170 public boolean hasNext() {
171 return occurrencesLeft > 0 || entryIterator.hasNext();
172 }
173
174 @Override
175 public E next() {
176 if (occurrencesLeft == 0) {
177 currentEntry = entryIterator.next();
178 occurrencesLeft = currentEntry.getValue().get();
179 }
180 occurrencesLeft--;
181 canRemove = true;
182 return currentEntry.getKey();
183 }
184
185 @Override
186 public void remove() {
187 checkRemove(canRemove);
188 int frequency = currentEntry.getValue().get();
189 if (frequency <= 0) {
190 throw new ConcurrentModificationException();
191 }
192 if (currentEntry.getValue().addAndGet(-1) == 0) {
193 entryIterator.remove();
194 }
195 size--;
196 canRemove = false;
197 }
198 }
199
200 @Override public int count(@Nullable Object element) {
201 Count frequency = Maps.safeGet(backingMap, element);
202 return (frequency == null) ? 0 : frequency.get();
203 }
204
205
206
207
208
209
210
211
212
213
214 @Override public int add(@Nullable E element, int occurrences) {
215 if (occurrences == 0) {
216 return count(element);
217 }
218 checkArgument(
219 occurrences > 0, "occurrences cannot be negative: %s", occurrences);
220 Count frequency = backingMap.get(element);
221 int oldCount;
222 if (frequency == null) {
223 oldCount = 0;
224 backingMap.put(element, new Count(occurrences));
225 } else {
226 oldCount = frequency.get();
227 long newCount = (long) oldCount + (long) occurrences;
228 checkArgument(newCount <= Integer.MAX_VALUE,
229 "too many occurrences: %s", newCount);
230 frequency.getAndAdd(occurrences);
231 }
232 size += occurrences;
233 return oldCount;
234 }
235
236 @Override public int remove(@Nullable Object element, int occurrences) {
237 if (occurrences == 0) {
238 return count(element);
239 }
240 checkArgument(
241 occurrences > 0, "occurrences cannot be negative: %s", occurrences);
242 Count frequency = backingMap.get(element);
243 if (frequency == null) {
244 return 0;
245 }
246
247 int oldCount = frequency.get();
248
249 int numberRemoved;
250 if (oldCount > occurrences) {
251 numberRemoved = occurrences;
252 } else {
253 numberRemoved = oldCount;
254 backingMap.remove(element);
255 }
256
257 frequency.addAndGet(-numberRemoved);
258 size -= numberRemoved;
259 return oldCount;
260 }
261
262
263 @Override public int setCount(@Nullable E element, int count) {
264 checkNonnegative(count, "count");
265
266 Count existingCounter;
267 int oldCount;
268 if (count == 0) {
269 existingCounter = backingMap.remove(element);
270 oldCount = getAndSet(existingCounter, count);
271 } else {
272 existingCounter = backingMap.get(element);
273 oldCount = getAndSet(existingCounter, count);
274
275 if (existingCounter == null) {
276 backingMap.put(element, new Count(count));
277 }
278 }
279
280 size += (count - oldCount);
281 return oldCount;
282 }
283
284 private static int getAndSet(Count i, int count) {
285 if (i == null) {
286 return 0;
287 }
288
289 return i.getAndSet(count);
290 }
291
292
293 @GwtIncompatible("java.io.ObjectStreamException")
294 @SuppressWarnings("unused")
295 private void readObjectNoData() throws ObjectStreamException {
296 throw new InvalidObjectException("Stream data required");
297 }
298
299 @GwtIncompatible("not needed in emulated source.")
300 private static final long serialVersionUID = -2250766705698539974L;
301 }